@NanoCode012
Collaborator

@NanoCode012 NanoCode012 commented Sep 2, 2025

Description

Adds support for linkedin/Liger-Kernel#860

Enable via liger_use_token_scaling: true

  • Requires an upstream Liger-Kernel release newer than 0.6.2

Summary by CodeRabbit

  • New Features

    • Added optional setting to enable token scaling for fused linear cross-entropy. When turned on with the existing FLCE option, each token’s loss is scaled by its predicted probability (detached), and this behavior is enforced automatically during training.
  • Documentation

    • Updated usage example to include the new FLCE-specific option, demonstrating how to enable token scaling alongside fused linear cross-entropy.
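To make the token-scaling behavior described above concrete, here is a minimal numerical sketch. It assumes "predicted probability" means the probability the model assigns to the target token, and it uses plain Python instead of the fused Liger kernel, so the function names are hypothetical and the detach is only noted in a comment.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def token_scaled_losses(batch_logits, targets):
    """Return (plain, scaled) per-token cross-entropy losses.

    scaled_i = p_i * plain_i, where p_i is the probability assigned to the
    target token (an assumption about what "predicted probability" means).
    In an autograd framework p_i would be detached, so it acts as a constant
    weight and contributes no gradient of its own.
    """
    plain, scaled = [], []
    for logits, target in zip(batch_logits, targets):
        p_target = softmax(logits)[target]
        loss = -math.log(p_target)
        plain.append(loss)
        scaled.append(p_target * loss)
    return plain, scaled


# Two tokens over a 3-word vocabulary: one confident, one uniform.
plain, scaled = token_scaled_losses(
    [[4.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
    targets=[0, 0],
)
```

Because the weight is a probability in (0, 1], the scaled loss never exceeds the plain loss, and tokens the model finds unlikely are down-weighted the most.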

@coderabbitai
Contributor

coderabbitai bot commented Sep 2, 2025

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough

Introduces an optional configuration, liger_use_token_scaling, for Liger’s fused linear cross-entropy (FLCE). The README usage example is updated and the field is added to LigerArgs. In the plugin's pre_model_load, when both FLCE and token scaling are enabled, runtime patches force use_token_scaling=True for both the FLCE function and the loss class initializer.
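The runtime patching described in the walkthrough can be sketched roughly as follows. This is not the PR's actual code: the function and class below are hypothetical stand-ins for the Liger FLCE function and loss class, and their signatures are assumptions.

```python
import functools


# Hypothetical stand-ins for the Liger function and loss class; the real
# signatures may differ.
def fused_linear_cross_entropy(hidden, weight, target, use_token_scaling=False):
    return {"use_token_scaling": use_token_scaling}


class FusedLinearCrossEntropyLoss:
    def __init__(self, ignore_index=-100, use_token_scaling=False):
        self.ignore_index = ignore_index
        self.use_token_scaling = use_token_scaling


def force_kwarg(func, name, value):
    """Wrap func so that keyword `name` is always set to `value`."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        kwargs[name] = value  # overrides whatever keyword the caller passed
        return func(*args, **kwargs)

    return wrapper


# Patch the function and the class initializer in place.
fused_linear_cross_entropy = force_kwarg(
    fused_linear_cross_entropy, "use_token_scaling", True
)
FusedLinearCrossEntropyLoss.__init__ = force_kwarg(
    FusedLinearCrossEntropyLoss.__init__, "use_token_scaling", True
)
```

One caveat of this approach: if a caller passes the argument positionally, the wrapper raises a TypeError for a duplicate value, so it only suits call sites that use keywords or leave the argument out.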

Changes

| Cohort | File(s) | Summary |
| --- | --- | --- |
| Docs: Liger usage update | src/axolotl/integrations/liger/README.md | Adds FLCE-specific option example: liger_use_token_scaling: true under a new section, without altering existing options. |
| Config schema: LigerArgs | src/axolotl/integrations/liger/args.py | Adds optional bool field liger_use_token_scaling (default None) with description; imports Pydantic Field; no other logic changes. |
| Plugin: FLCE token-scaling patching | src/axolotl/integrations/liger/plugin.py | If liger_fused_linear_cross_entropy and liger_use_token_scaling are true, wraps functional.liger_fused_linear_cross_entropy and LigerFusedLinearCrossEntropyLoss.__init__ to inject use_token_scaling=True on every call; other branches unchanged. |

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested reviewers

  • djsaunde

@github-actions
Contributor

github-actions bot commented Sep 2, 2025

📖 Documentation Preview: https://68b68ca04d39d4dd3d02e0fc--resonant-treacle-0fd729.netlify.app

Deployed on Netlify from commit 0b2795f

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
src/axolotl/integrations/liger/README.md (1)

21-23: Clarify version/feature dependency for token scaling.

Readers need to know this only works with FLCE and a Liger-Kernel build that includes the token-scaling feature (PR #860+).

 # FLCE-specific
-liger_use_token_scaling: true
+# Requires Liger-Kernel with token-scaling support (PR #860+) and FLCE enabled
+liger_use_token_scaling: true
src/axolotl/integrations/liger/args.py (1)

38-46: Guard misuse and clarify description.

Warn when token scaling is set without FLCE, and note the dependency in the field description.

     liger_use_token_scaling: bool | None = Field(
         default=None,
         json_schema_extra={
             "description": (
-                "Enables use_token_scaling in fused_linear_cross_entropy. "
-                "When True, each token's loss is multiplied by its predicted probability (detached from gradients)."
+                "Enables use_token_scaling in fused_linear_cross_entropy (FLCE). "
+                "When True, each token's loss is multiplied by its predicted probability (detached from gradients). "
+                "Requires `liger_fused_linear_cross_entropy: true` and a Liger-Kernel build with token-scaling support."
             )
         },
     )

Add a validator (outside this hunk) that warns when the flag would have no effect:

# place near other @model_validator(mode="before")
@model_validator(mode="before")
@classmethod
def check_token_scaling_requires_flce(cls, data):
    if data.get("liger_use_token_scaling") and not data.get("liger_fused_linear_cross_entropy"):
        LOG.warning(
            "`liger_use_token_scaling: true` has no effect unless `liger_fused_linear_cross_entropy: true` is also set."
        )
    return data
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0094a2d and 4aafd5f.

📒 Files selected for processing (3)
  • src/axolotl/integrations/liger/README.md (1 hunks)
  • src/axolotl/integrations/liger/args.py (2 hunks)
  • src/axolotl/integrations/liger/plugin.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
  • GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.8.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: preview
🔇 Additional comments (2)
src/axolotl/integrations/liger/args.py (1)

38-46: LGTM: adds well-scoped opt-in flag with clear schema.

src/axolotl/integrations/liger/plugin.py (1)

51-76: Add missing inspect import and verify patch logic

  • In src/axolotl/integrations/liger/plugin.py, add import inspect before using inspect.signature.
  • Ensure the guard skips the monkey-patch when use_token_scaling isn’t in the function or __init__ signature; manually test in an environment with Liger-Kernel installed to confirm no TypeError is raised.
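The signature guard suggested above might look roughly like this. It is a sketch under assumptions: accepts_kwarg is a hypothetical helper, and old_flce, new_flce, and NewLoss are stand-ins for older and newer Liger-Kernel entry points, not the library's real signatures.

```python
import inspect


def accepts_kwarg(func, name):
    """Return True if func can be called with keyword argument `name`."""
    try:
        params = inspect.signature(func).parameters
    except (TypeError, ValueError):
        # Some builtins and C extensions expose no signature; be conservative.
        return False
    param = params.get(name)
    if param is not None and param.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    ):
        return True
    # A **kwargs catch-all also accepts the keyword.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


# Hypothetical stand-ins for an older and a newer kernel release.
def old_flce(hidden, weight, target):
    return "old"


def new_flce(hidden, weight, target, use_token_scaling=False):
    return "new"


class NewLoss:
    def __init__(self, use_token_scaling=False):
        self.use_token_scaling = use_token_scaling


# Only patch when the installed version understands the keyword.
should_patch_old = accepts_kwarg(old_flce, "use_token_scaling")
should_patch_new = accepts_kwarg(new_flce, "use_token_scaling")
should_patch_init = accepts_kwarg(NewLoss.__init__, "use_token_scaling")
```

Skipping the patch when the keyword is missing avoids a TypeError at call time on older releases; logging a warning in that branch would tell the user the option was ignored.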

@codecov

codecov bot commented Sep 2, 2025

Codecov Report

❌ Patch coverage is 42.85714% with 12 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| src/axolotl/integrations/liger/plugin.py | 7.69% | 12 Missing ⚠️ |

@winglian
Collaborator

winglian commented Sep 2, 2025

This should be a draft, right? Since it needs a new Liger release

@NanoCode012 NanoCode012 marked this pull request as draft September 2, 2025 13:38